Recommendation Model with Approximate Item Matching

This notebook shows how to train a simple Neural Collaborative Filtering model for recommeding movies to users. We also show how learnt movie embeddings are stored in an appoximate similarity matching index, using Spotify's Annoy library, so that we can quickly find and recommend the most relevant movies to a given customer. We show how this index to search for similar movies.

In essense, this tutorial works as follows:

  1. Download the movielens dataset.
  2. Train a simple Neural Collaborative Model using TensorFlow custom estimator.
  3. Extract the learnt movie embeddings.
  4. Build an approximate similarity matching index for the movie embeddings.
  5. Export the trained model, which receives a user Id, and output the user embedding.

The recommendation is served as follows:

  1. Receives a user Id
  2. Get the user embedding from the exported model
  3. Find the similar movie embeddings to the user embedding in the index
  4. Return the movie Ids of these embeddings to recommend


!pip install annoy

import math
import os
import pandas as pd
import numpy as np
from datetime import datetime

import tensorflow as tf
from tensorflow import data

print "TensorFlow : {}".format(tf.__version__)

SEED = 19831060

TensorFlow : 1.12.0

1. Download Data

! wget -P data/
! unzip data/ -d data/
TRAIN_DATA_FILE = os.path.join(DATA_DIR, 'ml-latest-small/ratings.csv')

--2019-02-17 16:29:41--
Connecting to||:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 978202 (955K) [application/zip]
Saving to: ‘data/’ 100%[===================>] 955.28K  1.19MB/s    in 0.8s    

2019-02-17 16:29:42 (1.19 MB/s) - ‘data/’ saved [978202/978202]

Archive:  data/
   creating: data/ml-latest-small/
  inflating: data/ml-latest-small/links.csv  
  inflating: data/ml-latest-small/tags.csv  
  inflating: data/ml-latest-small/ratings.csv  
  inflating: data/ml-latest-small/README.txt  
  inflating: data/ml-latest-small/movies.csv  

ratings_data = pd.read_csv(TRAIN_DATA_FILE)

userId movieId rating timestamp
count 100836.000000 100836.000000 100836.000000 1.008360e+05
mean 326.127564 19435.295718 3.501557 1.205946e+09
std 182.618491 35530.987199 1.042529 2.162610e+08
min 1.000000 1.000000 0.500000 8.281246e+08
25% 177.000000 1199.000000 3.000000 1.019124e+09
50% 325.000000 2991.000000 3.500000 1.186087e+09
75% 477.000000 8122.000000 4.000000 1.435994e+09
max 610.000000 193609.000000 5.000000 1.537799e+09

userId movieId rating timestamp
0 1 1 4.0 964982703
1 1 3 4.0 964981247
2 1 6 4.0 964982224
3 1 47 5.0 964983815
4 1 50 5.0 964982931

movies_data = pd.read_csv(os.path.join(DATA_DIR, 'ml-latest-small/movies.csv'))

movieId title genres
0 1 Toy Story (1995) Adventure|Animation|Children|Comedy|Fantasy
1 2 Jumanji (1995) Adventure|Children|Fantasy
2 3 Grumpier Old Men (1995) Comedy|Romance
3 4 Waiting to Exhale (1995) Comedy|Drama|Romance
4 5 Father of the Bride Part II (1995) Comedy

2. Build the TensorFlow Model

2.1 Define Metadata

HEADER = ['userId', 'movieId', 'rating', 'timestamp']
HEADER_DEFAULTS = [0, 0, 0.0, 0]
TARGET_NAME = 'rating'
num_users = ratings_data.userId.max()
num_movies = movies_data.movieId.max()

2.2 Define Data Input Function

def make_input_fn(file_pattern, batch_size, num_epochs, 
    def _input_fn():
        dataset =
            shuffle= (mode==tf.estimator.ModeKeys.TRAIN)
        return dataset
    return _input_fn

2.3 Create Feature Columns

def create_feature_columns(embedding_size):
    feature_columns = []
                'userId', num_buckets=num_users + 1), 
                'movieId', num_buckets=num_movies + 1),
    return feature_columns

2.4 Define Model Function

def model_fn(features, labels, mode, params):
    feature_columns = create_feature_columns(params.embedding_size)
    user_layer = tf.feature_column.input_layer(
        features={'userId': features['userId']}, feature_columns=[feature_columns[0]])
    if mode != tf.estimator.ModeKeys.PREDICT:
        movie_layer = tf.feature_column.input_layer(
            features={'movieId': features['movieId']}, feature_columns=[feature_columns[1]])
        dot_product = tf.keras.layers.Dot(axes=1)([user_layer, movie_layer])
        logits = tf.clip_by_value(clip_value_min=0, clip_value_max=5, t=dot_product)

    predictions = None
    export_outputs = None
    loss = None
    train_op = None

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'user_embedding': user_layer}
        export_outputs = {'predictions': tf.estimator.export.PredictOutput(predictions)}
        loss = tf.losses.mean_squared_error(labels, tf.squeeze(logits))
            loss=loss, global_step=tf.train.get_global_step())

    loss = tf.losses.mean_squared_error(labels, tf.squeeze(logits))
    return tf.estimator.EstimatorSpec(

2.5 Create Estimator

def create_estimator(params, run_config):
    estimator = tf.estimator.Estimator(
    return estimator

2.6 Define Experiment

def train_and_evaluate_experiment(params, run_config):
    # TrainSpec ####################################
    train_input_fn = make_input_fn(
    train_spec = tf.estimator.TrainSpec(
        input_fn = train_input_fn,
    # EvalSpec ####################################
    eval_input_fn = make_input_fn(

    eval_spec = tf.estimator.EvalSpec(
        input_fn = eval_input_fn,

    if tf.gfile.Exists(run_config.model_dir):
        print("Removing previous artefacts...")
    print ''
    estimator = create_estimator(params, run_config)
    print ''
    time_start = datetime.utcnow() 
    print("Experiment started at {}".format(time_start.strftime("%H:%M:%S")))


    time_end = datetime.utcnow() 
    print("Experiment finished at {}".format(time_end.strftime("%H:%M:%S")))
    time_elapsed = time_end - time_start
    print("Experiment elapsed time: {} seconds".format(time_elapsed.total_seconds()))
    return estimator

2.7 Run Experiment with Parameters

MODELS_LOCATION = 'models/movieles'
MODEL_NAME = 'recommender_01'
model_dir = os.path.join(MODELS_LOCATION, MODEL_NAME)

params  =

run_config = tf.estimator.RunConfig(

estimator = train_and_evaluate_experiment(params, run_config)

Experiment finished at 19:14:27

Experiment elapsed time: 7.371008 seconds

3. Extract Movie Embeddings

def find_embedding_tensor():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(os.path.join(model_dir, 'model.ckpt-100000.meta'))
        saver.restore(sess, os.path.join(model_dir, 'model.ckpt-100000'))
        graph = tf.get_default_graph()
        trainable_tensors = map(str, graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
        for tensor in set(trainable_tensors):
            print tensor

def extract_embeddings():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(os.path.join(model_dir, 'model.ckpt-100000.meta'))
        saver.restore(sess, os.path.join(model_dir, 'model.ckpt-100000'))
        graph = tf.get_default_graph()
        weights_tensor = graph.get_tensor_by_name('input_layer_1/movieId_embedding/embedding_weights:0')
        weights = np.array(
    embeddings = {}
    for i in range(weights.shape[0]):
        embeddings[i] = weights[i]
    return embeddings

embeddings = extract_embeddings()

4. Build Annoy Index

from annoy import AnnoyIndex

def build_embeddings_index(num_trees):
    total_items = 0
    annoy_index = AnnoyIndex(params.embedding_size, metric='angular')
    for item_id in embeddings.keys():
        annoy_index.add_item(item_id, embeddings[item_id])
        total_items += 1
    print "{} items where added to the index".format(total_items)
    print "Index is built"
    return annoy_index

index = build_embeddings_index(100)

frequent_movie_ids = list(ratings_data.movieId.value_counts().index[:15])

In [113]:

movieId title genres
0 1 Toy Story (1995) Adventure|Animation|Children|Comedy|Fantasy
46 50 Usual Suspects, The (1995) Crime|Mystery|Thriller
97 110 Braveheart (1995) Action|Drama|War
224 260 Star Wars: Episode IV - A New Hope (1977) Action|Adventure|Sci-Fi
257 296 Pulp Fiction (1994) Comedy|Crime|Drama|Thriller
277 318 Shawshank Redemption, The (1994) Crime|Drama
314 356 Forrest Gump (1994) Comedy|Drama|Romance|War
418 480 Jurassic Park (1993) Action|Adventure|Sci-Fi|Thriller
461 527 Schindler's List (1993) Drama|War
507 589 Terminator 2: Judgment Day (1991) Action|Sci-Fi
510 593 Silence of the Lambs, The (1991) Crime|Horror|Thriller
898 1196 Star Wars: Episode V - The Empire Strikes Back... Action|Adventure|Sci-Fi
1939 2571 Matrix, The (1999) Action|Sci-Fi|Thriller
2145 2858 American Beauty (1999) Drama|Romance
2226 2959 Fight Club (1999) Action|Crime|Drama|Thriller

def get_similar_movies(movie_id, num_matches=5):
    similar_movie_ids = index.get_nns_by_item(
        movie_id, num_matches, search_k=-1, include_distances=False)
    similar_movies = movies_data[movies_data['movieId'].isin(similar_movie_ids)].title
    return similar_movies

for movie_id in frequent_movie_ids:
    movie_title = movies_data[movies_data['movieId'] == movie_id].title.values[0]
    print "Movie: {}".format(movie_title)
    similar_movies = get_similar_movies(movie_id)
    print "Similar Movies:"
    print similar_movies
    print "--------------------------------------"

Movie: Forrest Gump (1994)
Similar Movies:
55                  Mr. Holland's Opus (1995)
314                       Forrest Gump (1994)
1956    Open Your Eyes (Abre los ojos) (1997)
2372                   Green Mile, The (1999)
4867                    50 First Dates (2004)
Name: title, dtype: object
Movie: Shawshank Redemption, The (1994)
Similar Movies:
277          Shawshank Redemption, The (1994)
955                          Duck Soup (1933)
1956    Open Your Eyes (Abre los ojos) (1997)
2462              Boondock Saints, The (2000)
7466                King's Speech, The (2010)
Name: title, dtype: object
Movie: Pulp Fiction (1994)
Similar Movies:
257                                   Pulp Fiction (1994)
2226                                    Fight Club (1999)
2250                      Who Framed Roger Rabbit? (1988)
6310    Borat: Cultural Learnings of America for Make ...
Name: title, dtype: object
Movie: Silence of the Lambs, The (1991)
Similar Movies:
510     Silence of the Lambs, The (1991)
941                         Glory (1989)
1032                    Cape Fear (1962)
2078             Sixth Sense, The (1999)
5374             Incredibles, The (2004)
Name: title, dtype: object
Movie: Matrix, The (1999)
Similar Movies:
418                                  Jurassic Park (1993)
509                                         Batman (1989)
793                                       Die Hard (1988)
911     Star Wars: Episode VI - Return of the Jedi (1983)
1939                                   Matrix, The (1999)
Name: title, dtype: object
Movie: Star Wars: Episode IV - A New Hope (1977)
Similar Movies:
224             Star Wars: Episode IV - A New Hope (1977)
898     Star Wars: Episode V - The Empire Strikes Back...
911     Star Wars: Episode VI - Return of the Jedi (1983)
969                             Back to the Future (1985)
2097                                     Airplane! (1980)
Name: title, dtype: object
Movie: Jurassic Park (1993)
Similar Movies:
63          Fair Game (1995)
253          Outbreak (1995)
418     Jurassic Park (1993)
793          Die Hard (1988)
2608             Hook (1991)
Name: title, dtype: object
Movie: Braveheart (1995)
Similar Movies:
31      Twelve Monkeys (a.k.a. 12 Monkeys) (1995)
97                              Braveheart (1995)
337                              True Lies (1994)
1267                      Truman Show, The (1998)
1803      First Blood (Rambo: First Blood) (1982)
Name: title, dtype: object
Movie: Terminator 2: Judgment Day (1991)
Similar Movies:
474                         Blade Runner (1982)
507           Terminator 2: Judgment Day (1991)
939                      Terminator, The (1984)
1469                         Poltergeist (1982)
1803    First Blood (Rambo: First Blood) (1982)
Name: title, dtype: object
Movie: Schindler's List (1993)
Similar Movies:
13                                 Nixon (1995)
461                     Schindler's List (1993)
561     Some Folks Call It a Sling Blade (1993)
922              Godfather: Part II, The (1974)
2110                  Christmas Story, A (1983)
Name: title, dtype: object
Movie: Fight Club (1999)
Similar Movies:
596     Ghost in the Shell (Kôkaku kidôtai) (1995)
1706                                   Antz (1998)
1734                     American History X (1998)
2226                             Fight Club (1999)
6676                              In Bruges (2008)
Name: title, dtype: object
Movie: Toy Story (1995)
Similar Movies:
0                             Toy Story (1995)
506                             Aladdin (1992)
2436    Hand That Rocks the Cradle, The (1992)
3568                     Monsters, Inc. (2001)
7355                        Toy Story 3 (2010)
Name: title, dtype: object
Movie: Star Wars: Episode V - The Empire Strikes Back (1980)
Similar Movies:
224             Star Wars: Episode IV - A New Hope (1977)
826                              Dial M for Murder (1954)
898     Star Wars: Episode V - The Empire Strikes Back...
911     Star Wars: Episode VI - Return of the Jedi (1983)
4711                                  Quick Change (1990)
Name: title, dtype: object
Movie: American Beauty (1999)
Similar Movies:
962         Deer Hunter, The (1978)
1290    Sweet Hereafter, The (1997)
2145         American Beauty (1999)
2226              Fight Club (1999)
3141                 Memento (2000)
Name: title, dtype: object
Movie: Usual Suspects, The (1995)
Similar Movies:
46               Usual Suspects, The (1995)
277        Shawshank Redemption, The (1994)
2462            Boondock Saints, The (2000)
3234    A.I. Artificial Intelligence (2001)
5917                   Batman Begins (2005)
Name: title, dtype: object

5. Export the Model

This needed to receive a userId and produce the embedding for the user.

def make_serving_input_receiver_fn():
    return tf.estimator.export.build_raw_serving_input_receiver_fn(
        {'userId': tf.placeholder(shape=[None], dtype=tf.int32)}

export_dir = os.path.join(model_dir, 'export')

if tf.gfile.Exists(export_dir):

import os

export_dir = os.path.join(model_dir, "export")
saved_model_dir = os.path.join(
    export_dir, [f for f in os.listdir(export_dir) if f.isdigit()][0])


predictor_fn = tf.contrib.predictor.from_saved_model(
    export_dir = saved_model_dir,

output = predictor_fn({'userId': [1]})

{u'user_embedding': array([[-0.04079459, -0.06252338,  0.01964831, -0.03159623,  0.01765972,
         0.00015648,  0.0686218 ,  0.01872032,  0.04238764,  0.03700782,
         0.00166043,  0.00917281,  0.01879879, -0.01652114, -0.02870869,
         0.00668285]], dtype=float32)}

Serve Movie Recommendations to a User

In [190]:
def recommend_new_movies(userId, num_recommendations=5):
    watched_movie_ids = list(ratings_data[ratings_data['userId']==userId]['movieId'])
    user_emebding = predictor_fn({'userId': [userId]})['user_embedding'][0]
    similar_movie_ids = index.get_nns_by_vector(
        user_emebding, num_recommendations + len(watched_movie_ids), search_k=-1, include_distances=False)
    recommended_movie_ids = set(similar_movie_ids) - set(watched_movie_ids)
    similar_movies = movies_data[movies_data['movieId'].isin(recommended_movie_ids)].title
    return similar_movies

In [191]:
frequent_user_ids = list((ratings_data.userId.value_counts().index[-350:]))[:5] 
print recommend_movies(418)

3857            Dangerous Lives of Altar Boys, The (2002)
8140    Wolf Children (Okami kodomo no ame to yuki) (2...
8429                                          Chef (2014)
9610                                     Cage Dive (2017)
9729                                         Bunny (1998)
Name: title, dtype: object


Author: Khalid Salama

Disclaimer: This is not an official Google product. This sample code provided for an educational purpose.

